{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bTAeNLV3WdB0"
      },
      "source": [
        "本文涉及的jupter notebook在[篇章4代码库中](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1)。\n",
        "\n",
        "建议直接使用google colab notebook打开本教程，可以快速下载相关数据集和模型。\n",
        "如果您正在google的colab中打开这个notebook，您可能需要安装Transformers和🤗Datasets库。将以下命令取消注释即可安装。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bAWdZQdTWdB3"
      },
      "outputs": [],
      "source": [
        "# ! pip install datasets transformers \n",
        "# -i https://pypi.tuna.tsinghua.edu.cn/simple"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7FC_ZTXsWdB3"
      },
      "source": [
        "如果您是在本地机器上打开这个jupyter笔记本，请确保您的环境安装了上述库的最新版本。\n",
        "\n",
        "您可以在[这里](https://github.com/huggingface/transformers/tree/master/examples/language-modeling)找到这个jupyter笔记本的具体的python脚本文件，还可以通过分布式的方式使用多个gpu或tpu来微调您的模型。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BgQvrzh3WdB4"
      },
      "source": [
        "# 微调语言模型"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hqHxDcitWdB4"
      },
      "source": [
        "在当前jupyter笔记本中，我们将说明如何使用语言模型任务微调任意[🤗Transformers](https://github.com/huggingface/transformers) 模型。 \n",
        "\n",
        "本教程将涵盖两种类型的语言建模任务:\n",
        "\n",
        "+ 因果语言模型（Causal language modeling，CLM）：模型需要预测句子中的下一位置处的字符（类似BERT类模型的decoder和GPT，从左往右输入字符）。为了确保模型不作弊，模型会使用一个注意掩码防止模型看到之后的字符。例如，当模型试图预测句子中的i+1位置处的字符时，这个掩码将阻止它访问i位置之后的字符。\n",
        "\n",
        "![推理表示因果语言建模任务图片](./images/causal_language_modeling.png)\n",
        "\n",
        "+ 掩蔽语言建模（Masked language modeling，MLM）：模型需要恢复输入中被\"MASK\"掉的一些字符（BERT类模型的预训练任务）。这种方式模型可以看到整个句子，因此模型可以根据“\\[MASK\\]”标记之前和之后的字符来预测该位置被“\\[MASK\\]”之前的字符。\n",
        "\n",
        "![Widget inference representing the masked language modeling task](images/masked_language_modeling.png)\n",
        "\n",
        "接下来，我们将说明如何轻松地为每个任务加载和预处理数据集，以及如何使用“Trainer”API对模型进行微调。\n",
        "\n",
        "当然您也可以直接在分布式环境或TPU上运行该jupyter笔记本的python脚本版本，可以在[examples文件夹](https://github.com/huggingface/transformers/tree/master/examples)中找到。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GobLIFiRWdB5"
      },
      "source": [
        "## 准备数据"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tfkh562BWdB5"
      },
      "source": [
        "在接下来的这些任务中，我们将使用[Wikitext 2](https://huggingface.co/datasets/wikitext#data-instances)数据集作为示例。您可以通过🤗Datasets库加载该数据集："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vtDCQuMSWdB5"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IqMUxoJrWdB7"
      },
      "source": [
        "如果碰到以下错误：\n",
        "![request Error](images/request_error.png)\n",
        "\n",
        "解决方案:\n",
        "\n",
        "MAC用户: 在 ```/etc/hosts``` 文件中添加一行 ```199.232.68.133  raw.githubusercontent.com```\n",
        "\n",
        "Windowso用户: 在 ```C:\\Windows\\System32\\drivers\\etc\\hosts```  文件中添加一行 ```199.232.68.133  raw.githubusercontent.com```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wjl5FpYDWdB7"
      },
      "source": [
        "当然您也可以用公开在[hub](https://huggingface.co/datasets)上的任何数据集替换上面的数据集，或者使用您自己的文件。只需取消注释以下单元格，并将路径替换为将导致您的文件路径："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fncDchlaWdB7"
      },
      "outputs": [],
      "source": [
        "# datasets = load_dataset(\"text\", data_files={\"train\": path_to_train.txt, \"validation\": path_to_validation.txt}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ODIsscsTWdB8"
      },
      "source": [
        "您还可以从csv或JSON文件加载数据集，更多信息请参阅[完整文档](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files)。\n",
        "\n",
        "要访问一个数据中实际的元素，您需要先选择一个key，然后给出一个索引:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "H887z9CbWdB8",
        "outputId": "f9d05402-f99b-40da-b672-887e6a8c5597"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'text': ' The game \\'s battle system , the BliTZ system , is carried over directly from Valkyira Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters \\' turns . Each character has a field and distance of movement limited by their Action Gauge . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific \" Potentials \" , skills unique to each character . They are divided into \" Personal Potential \" , which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character , and \" Battle Potentials \" , which are grown throughout the game and always grant boons to a character . To learn Battle Potentials , each character has a unique \" Masters Table \" , a grid @-@ based skill table that can be used to acquire and link different skills . Characters also have Special Abilities that grant them temporary boosts on the battlefield : Kurt can activate \" Direct Command \" and move around the battlefield without depleting his Action Point gauge , the character Reila can shift into her \" Valkyria Form \" and become invincible , while Imca can target multiple enemy units with her heavy weapon . \\n'}"
            ]
          },
          "execution_count": 6,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "datasets[\"train\"][10]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y3vbv6yHWdB8"
      },
      "source": [
        "为了快速了解数据的结构，下面的函数将显示数据集中随机选取的一些示例。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "II9ha_LmWdB9"
      },
      "outputs": [],
      "source": [
        "from datasets import ClassLabel\n",
        "import random\n",
        "import pandas as pd\n",
        "from IPython.display import display, HTML\n",
        "\n",
        "def show_random_elements(dataset, num_examples=10):\n",
        "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
        "    picks = []\n",
        "    for _ in range(num_examples):\n",
        "        pick = random.randint(0, len(dataset)-1)\n",
        "        while pick in picks:\n",
        "            pick = random.randint(0, len(dataset)-1)\n",
        "        picks.append(pick)\n",
        "    \n",
        "    df = pd.DataFrame(dataset[picks])\n",
        "    for column, typ in dataset.features.items():\n",
        "        if isinstance(typ, ClassLabel):\n",
        "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
        "    display(HTML(df.to_html()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 427
        },
        "id": "LaCaYQyJWdB9",
        "outputId": "8fcf2a87-fa7c-46b1-bd03-26325ce69da9"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>text</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>MD 194D is the designation for an unnamed 0 @.@ 02 @-@ mile ( 0 @.@ 032 km ) connector between MD 194 and MD 853E , the old alignment that parallels the northbound direction of the modern highway south of Angell Road . \\n</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>My sense , as though of hemlock I had drunk , \\n</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td></td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>A mimed stage show , Thunderbirds : F.A.B. , has toured internationally and popularised a staccato style of movement known colloquially as the \" Thunderbirds walk \" . The production has periodically been revived as Thunderbirds : F.A.B. – The Next Generation . \\n</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td></td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td></td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>In his 1998 autobiography For the Love of the Game , Jordan wrote that he had been preparing for retirement as early as the summer of 1992 . The added exhaustion due to the Dream Team run in the 1992 Olympics solidified Jordan 's feelings about the game and his ever @-@ growing celebrity status . Jordan 's announcement sent shock waves throughout the NBA and appeared on the front pages of newspapers around the world . \\n</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>Research on new wildlife collars may be able to reduce human @-@ animal conflicts by predicting when and where predatory animals hunt . This can not only save human lives and the lives of their pets and livestock but also save these large predatory mammals that are important to the balance of ecosystems . \\n</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>\" Love Me Like You \" ( Christmas Mix ) – 3 : 29 \\n</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td></td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "show_random_elements(datasets[\"train\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LH0Uk_OsWdB9"
      },
      "source": [
        "正如我们所看到的，一些文本是维基百科文章的完整段落，而其他的只是标题或空行。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Nu5lPu8WdB-"
      },
      "source": [
        "## 因果语言模型（Causal Language Modeling，CLM）"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v7gOchUNWdB-"
      },
      "source": [
        "对于因果语言模型(CLM)，我们首先获取到数据集中的所有文本，并在它们被分词后将它们连接起来。然后，我们将在特定序列长度的例子中拆分它们。通过这种方式，模型将接收如下的连续文本块:\n",
        "\n",
        "```\n",
        "文本1\n",
        "```\n",
        "或\n",
        "```\n",
        "文本1结尾 [BOS_TOKEN] 文本2开头\n",
        "```\n",
        "\n",
        "取决于它们是否跨越数据集中的几个原始文本。标签将与输入相同，但向左移动。\n",
        "\n",
        "在本例中，我们将使用[`distilgpt2`](https://huggingface.co/distilgpt2) 模型。您同样也可以选择[这里](https://huggingface.co/models?filter=causal-lm)列出的任何一个checkpoint:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z37txOiBWdB-"
      },
      "outputs": [],
      "source": [
        "model_checkpoint = \"distilgpt2\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mk8BWvYWWdB-"
      },
      "source": [
        "为了用训练模型时使用的词汇对所有文本进行标记，我们必须下载一个预先训练过的分词器（Tokenizer）。而这些操作都可以由AutoTokenizer类完成:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mQwZ5UssWdB_"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer\n",
        "    \n",
        "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hAQJGvMxWdB_"
      },
      "source": [
        "我们现在可以对所有的文本调用分词器，该操作可以简单地使用来自Datasets库的map方法实现。首先，我们定义一个在文本上调用标记器的函数:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wxhKKMYgWdB_"
      },
      "outputs": [],
      "source": [
        "def tokenize_function(examples):\n",
        "    return tokenizer(examples[\"text\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FM_kMpbCWdB_"
      },
      "source": [
        "然后我们将它应用到datasets对象中的分词，使用```batch=True```和```4```个进程来加速预处理。而之后我们并不需要```text```列，所以将其舍弃。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rNb1U12YWdCA"
      },
      "outputs": [],
      "source": [
        "tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"text\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bc2niRZJWdCA"
      },
      "source": [
        "如果我们现在查看数据集的一个元素，我们会看到文本已经被模型所需的input_ids所取代:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HgC4UWv8WdCA",
        "outputId": "e3257089-88b6-4b15-fbe1-45073073ad3e"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
              " 'input_ids': [796, 569, 18354, 7496, 17740, 6711, 796, 220, 198]}"
            ]
          },
          "execution_count": 13,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "tokenized_datasets[\"train\"][1]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TpKE1TJXWdCA"
      },
      "source": [
        "下一步就有点小困难了：我们需要将所有文本连接在一起，然后将结果分割成特定`block_size`的小块。为此，我们将再次使用`map`方法，并使用选项`batch=True`。这个选项允许我们通过返回不同数量的样本来改变数据集中的样本数量。通过这种方式，我们可以从一批示例中创建新的示例。\n",
        "\n",
        "首先，我们需要获取预训练模型时所使用的最大长度。最大长度在这里设置为128，以防您的显存爆炸💥。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uEnLI7LJWdCB"
      },
      "outputs": [],
      "source": [
        "# block_size = tokenizer.model_max_length\n",
        "block_size = 128"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IwZu_FSjWdCB"
      },
      "source": [
        "然后我们编写预处理函数来对我们的文本进行分组:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OhhL0v2FWdCB"
      },
      "outputs": [],
      "source": [
        "def group_texts(examples):\n",
        "    # 拼接所有文本\n",
        "    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
        "    total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
        "    # 我们将余数对应的部分去掉。但如果模型支持的话，可以添加padding，您可以根据需要定制此部件。\n",
        "    total_length = (total_length // block_size) * block_size\n",
        "    # 通过max_len进行分割。\n",
        "    result = {\n",
        "        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
        "        for k, t in concatenated_examples.items()\n",
        "    }\n",
        "    result[\"labels\"] = result[\"input_ids\"].copy()\n",
        "    return result"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DpsALat7WdCC"
      },
      "source": [
        "首先注意，我们复制了标签的输入。\n",
        "\n",
        "这是因为🤗transformer库的模型默认向右移动，所以我们不需要手动操作。\n",
        "\n",
        "还要注意，在默认情况下，`map`方法将发送一批1,000个示例，由预处理函数处理。因此，在这里，我们将删除剩余部分，使连接的标记化文本每1000个示例为`block_size`的倍数。您可以通过传递更高的批处理大小来调整此行为(当然这也会被处理得更慢)。你也可以使用`multiprocessing`来加速预处理:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lmoi9YUZWdCC"
      },
      "outputs": [],
      "source": [
        "lm_datasets = tokenized_datasets.map(\n",
        "    group_texts,\n",
        "    batched=True,\n",
        "    batch_size=1000,\n",
        "    num_proc=4,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qPkZvXnCWdCD"
      },
      "source": [
        "现在我们可以检查数据集是否发生了变化：现在样本包含了`block_size`连续字符块，可能跨越了几个原始文本。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 86
        },
        "id": "49I25iXJWdCD",
        "outputId": "f7f0364e-7ac0-44e0-aed1-d483b1dda631"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "' game and follows the \" Nameless \", a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit \" Calamaty Raven \". \\n The game began development in 2010, carrying over a large portion of the work done on Valkyria Chronicles II. While it retained the standard features of the series, it also underwent multiple adjustments, such as making the game more forgiving for series newcomers. Character designer Raita Honjou and composer Hitoshi Sakimoto both returned from previous entries, along with Valkyria Chronicles II director Takeshi Oz'"
            ]
          },
          "execution_count": 17,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "tokenizer.decode(lm_datasets[\"train\"][1][\"input_ids\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qVkDCaz5WdCD"
      },
      "source": [
        "既然数据已经清理完毕，我们就可以实例化我们的训练器了。我们将建立一个模型:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 66,
          "referenced_widgets": [
            "f9a94ec0a95a435b8f58cc67994099f7",
            "782ab18684dd4f36a403d180138e8f1d",
            "6d8e4de69163477891b6636d869f6e4e",
            "27a5fef432d24131b2678f9cf5906a4f",
            "36d6db8a50de4a3c886647f31b60b621",
            "969eb6b77105455cab015cb8574b1bd3",
            "a0d0fe9dec9b413fa80fa4d678dcb9c3",
            "7ecafb32516947aa9ace74da667d9665"
          ]
        },
        "id": "jm5DOjPOWdCD",
        "outputId": "5ec1e6e5-66ef-4033-fdc4-3ee3ee6e0dd9"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f9a94ec0a95a435b8f58cc67994099f7",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=352833716.0, style=ProgressStyle(descri…"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        }
      ],
      "source": [
        "from transformers import AutoModelForCausalLM\n",
        "model = AutoModelForCausalLM.from_pretrained(model_checkpoint)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sAFO2Y6_WdCE"
      },
      "source": [
        "检查torch版本"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "11XlW2ogWdCE",
        "outputId": "d8e8c62e-dfc0-43d1-b377-2b1796a52f56"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "True\n",
            "1.8.1+cu101\n"
          ]
        }
      ],
      "source": [
        "\n",
        "import importlib.util\n",
        "import importlib_metadata\n",
        "a = importlib.util.find_spec(\"torch\") is not None\n",
        "print(a)\n",
        "_torch_version = importlib_metadata.version(\"torch\")\n",
        "print(_torch_version)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y8_qK5i3WdCE"
      },
      "source": [
        "和一些`TrainingArguments`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WWhPVy82WdCE"
      },
      "outputs": [],
      "source": [
        "from transformers import Trainer, TrainingArguments"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qGz6BxoOWdCF"
      },
      "outputs": [],
      "source": [
        "training_args = TrainingArguments(\n",
        "    \"test-clm\",\n",
        "    evaluation_strategy = \"epoch\",\n",
        "    learning_rate=2e-5,\n",
        "    weight_decay=0.01,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "efdl12AIWdCF"
      },
      "source": [
        "我们把这些都传递给`Trainer`类:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1ukmE65zWdCF"
      },
      "outputs": [],
      "source": [
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=lm_datasets[\"train\"][:1000],\n",
        "    eval_dataset=lm_datasets[\"validation\"][:1000],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UDn2o1gSWdCF"
      },
      "source": [
        "然后就可以训练我们的模型🌶:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a55CO2xGWdCF"
      },
      "outputs": [],
      "source": [
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PAFX3mCwWdCG"
      },
      "source": [
        "一旦训练完成，我们就可以评估我们的模型，得到它在验证集上的perplexity，如下所示:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g1A7eBP3WdCG"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "eval_results = trainer.evaluate()\n",
        "print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lJzcfmsVWdCG"
      },
      "source": [
        "## 掩蔽语言模型（Mask Language Modeling，MLM）"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sr5UHTPjWdCG"
      },
      "source": [
        "掩蔽语言模型(MLM)我们将使用相同的数据集预处理和以前一样用一个额外的步骤：\n",
        "\n",
        "我们将随机\"MASK\"一些字符(使用\"[MASK]\"进行替换)以及调整标签为只包含在\"[MASK]\"位置处的标签(因为我们不需要预测没有被\"MASK\"的字符)。\n",
        "\n",
        "在本例中，我们将使用[`distilroberta-base`](https://huggingface.co/distilroberta-base)模型。您同样也可以选择[这里](https://huggingface.co/models?filter=causal-lm)列出的任何一个checkpoint:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2X4qJPeXWdCG"
      },
      "outputs": [],
      "source": [
        "model_checkpoint = \"distilroberta-base\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GMDHywzqWdCH"
      },
      "source": [
        "我们可以像之前一样应用相同的分词器函数，我们只需要更新我们的分词器来使用刚刚选择的checkpoint:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sgIqOa4uWdCH"
      },
      "outputs": [],
      "source": [
        "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
        "tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"text\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O4BALsgJWdCH"
      },
      "source": [
        "像之前一样，我们把文本分组在一起，并把它们分成长度为`block_size`的样本。如果您的数据集由单独的句子组成，则可以跳过这一步。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d4jJo5X2WdCH"
      },
      "outputs": [],
      "source": [
        "lm_datasets = tokenized_datasets.map(\n",
        "    group_texts,\n",
        "    batched=True,\n",
        "    batch_size=1000,\n",
        "    num_proc=4,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2wbjmPZQWdCI"
      },
      "source": [
        "剩下的和我们之前的做法非常相似，只有两个例外。首先我们使用一个适合掩蔽语言模型的模型:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LkJuOG4oWdCI"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoModelForMaskedLM\n",
        "model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MVKOscEdWdCI"
      },
      "source": [
        "其次，我们使用一个特殊的data_collator。data_collator是一个函数，负责获取样本并将它们批处理成张量。\n",
        "\n",
        "在前面的例子中，我们没有什么特殊的事情要做，所以我们只使用这个参数的默认值。这里我们要做随机\"MASK\"。\n",
        "\n",
        "我们可以将其作为预处理步骤(`tokenizer`)进行处理，但在每个阶段，字符总是以相同的方式被掩盖。通过在data_collator中执行这一步，我们可以确保每次检查数据时都以新的方式完成随机掩蔽。\n",
        "\n",
        "为了实现掩蔽，`Transformers`为掩蔽语言模型提供了一个`DataCollatorForLanguageModeling`。我们可以调整掩蔽的概率:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A-k8wJK7WdCI"
      },
      "outputs": [],
      "source": [
        "from transformers import DataCollatorForLanguageModeling\n",
        "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m83JcPGyWdCI"
      },
      "source": [
        "然后我们要把所有的东西交给trainer，然后开始训练:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I12S2ZQxWdCJ"
      },
      "outputs": [],
      "source": [
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=lm_datasets[\"train\"][:1000],\n",
        "    eval_dataset=lm_datasets[\"validation\"][:100],\n",
        "    data_collator=data_collator,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u_4PHx1CWdCJ"
      },
      "outputs": [],
      "source": [
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MDKbrOmzWdCJ"
      },
      "source": [
        "像以前一样，我们可以在验证集上评估我们的模型。\n",
        "\n",
        "与CLM目标相比，困惑度要低得多，因为对于MLM目标，我们只需要对隐藏的令牌(在这里占总数的15%)进行预测，同时可以访问其余的令牌。\n",
        "\n",
        "因此，对于模型来说，这是一项更容易的任务。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "60hUa-W5WdCJ"
      },
      "outputs": [],
      "source": [
        "eval_results = trainer.evaluate()\n",
        "print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uPa5UTWWWdCK"
      },
      "outputs": [],
      "source": [
        "不要忘记将你的模型[上传](https://huggingface.co/transformers/model_sharing.html)到[🤗 模型中心](https://huggingface.co/models)。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "4.5-生成任务-语言模型",
      "provenance": []
    },
    "interpreter": {
      "hash": "3bfce0b4c492a35815b5705a19fe374a7eea0baaa08b34d90450caf1fe9ce20b"
    },
    "kernelspec": {
      "display_name": "Python 3.8.10 64-bit ('venv': virtualenv)",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": ""
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "27a5fef432d24131b2678f9cf5906a4f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7ecafb32516947aa9ace74da667d9665",
            "placeholder": "​",
            "style": "IPY_MODEL_a0d0fe9dec9b413fa80fa4d678dcb9c3",
            "value": " 353M/353M [00:11&lt;00:00, 31.6MB/s]"
          }
        },
        "36d6db8a50de4a3c886647f31b60b621": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": "initial"
          }
        },
        "6d8e4de69163477891b6636d869f6e4e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "Downloading: 100%",
            "description_tooltip": null,
            "layout": "IPY_MODEL_969eb6b77105455cab015cb8574b1bd3",
            "max": 352833716,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_36d6db8a50de4a3c886647f31b60b621",
            "value": 352833716
          }
        },
        "782ab18684dd4f36a403d180138e8f1d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7ecafb32516947aa9ace74da667d9665": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "969eb6b77105455cab015cb8574b1bd3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a0d0fe9dec9b413fa80fa4d678dcb9c3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f9a94ec0a95a435b8f58cc67994099f7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_6d8e4de69163477891b6636d869f6e4e",
              "IPY_MODEL_27a5fef432d24131b2678f9cf5906a4f"
            ],
            "layout": "IPY_MODEL_782ab18684dd4f36a403d180138e8f1d"
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}